{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HxxOWz9dpsYj"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/drive/1Z_FEocd1xLg1KEh8FCs8ho61H5ItJwHA?usp=sharing\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n",
        "\n",
        "###\tWhat is Re-ranking?\n",
        "\n",
        "Re-ranking in RAG is a critical process that refines and reorders the initially retrieved information before it's fed into a generative AI model. It acts as a smart filter, ensuring that the most relevant and high-quality content is prioritized for the generation task.\n",
        "\n",
        "### Key aspects:\n",
        "1. Relevance optimization: Improves the quality of information used by the LLM.\n",
        "2. Intelligent sorting: Uses advanced algorithms to reassess and reorder retrieved passages.\n",
        "3. Context consideration: Takes into account the query intent and user context.\n",
        "4. Integration point: Sits between retrieval and generation components in the RAG pipeline.\n",
        "\n",
        "By effectively re-ranking retrieved information, RAG systems can significantly enhance the accuracy, relevance, and overall quality of the generated AI responses.\n",
        "\n",
        "\n",
        "### Setup\n",
        "\n",
        "1. **[LLM](https://groq.com/):** Groq's free Open source LLM endpoints([Groq API Key](https://console.groq.com/keys))\n",
        "2. **[Vector Store](https://www.pinecone.io/learn/vector-database/):** [ChromaDB](https://www.trychroma.com/)\n",
        "3. **[Embedding Model](https://qdrant.tech/articles/what-are-embeddings/):** [nomic-embed-text-v1.5](https://www.nomic.ai/blog/posts/nomic-embed-text-v1)\n",
        "4. **[LLM Framework](https://python.langchain.com/v0.2/docs/introduction/):** LangChain\n",
        "5. **[Huggingface API Key](https://huggingface.co/settings/tokens)**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y3axTI0sp5Hg"
      },
      "source": [
        "# Install required libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ShxTNxM5gqtr",
        "outputId": "8af9c2e6-a725-48da-c61a-978827dae146"
      },
      "outputs": [],
      "source": [
        "!pip install -q -U \\\n",
        "     Sentence-transformers==3.0.1 \\\n",
        "     langchain==0.3.19 \\\n",
        "     langchain-groq==0.2.4 \\\n",
        "     langchain-community==0.3.18 \\\n",
        "     langchain-huggingface==0.1.2 \\\n",
        "     einops==0.8.1 \\\n",
        "     chromadb==0.6.3 \\\n",
        "     flashrank==0.2.10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9jJ1vqs-p_Zx"
      },
      "source": [
        "### Import related libraries related to Langchain, HuggingfaceEmbedding"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RL-3LsYogoH5",
        "outputId": "e484a110-13d3-4a05-eb3e-c0abf47f88cf"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:langchain_community.utils.user_agent:USER_AGENT environment variable not set, consider setting it to identify your requests.\n"
          ]
        }
      ],
      "source": [
        "# Import Libraries\n",
        "from langchain_huggingface import HuggingFaceEmbeddings\n",
        "from langchain_groq import ChatGroq\n",
        "\n",
        "from langchain.vectorstores import Chroma\n",
        "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
        "from langchain.text_splitter import CharacterTextSplitter\n",
        "from langchain.prompts import ChatPromptTemplate\n",
        "from langchain.document_loaders import WebBaseLoader\n",
        "from langchain.retrievers import ContextualCompressionRetriever\n",
        "from langchain.retrievers.document_compressors import (\n",
        "    LLMChainExtractor,\n",
        "    EmbeddingsFilter,\n",
        ")\n",
        "from langchain.retrievers.document_compressors.flashrank_rerank import FlashrankRerank\n",
        "# Import the RetrievalQA chain for question-answering tasks\n",
        "from langchain.chains import RetrievalQA"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "GT55z5AkhyOW"
      },
      "outputs": [],
      "source": [
        "import getpass\n",
        "import os"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F6UeDlrgqI2A"
      },
      "source": [
        "#### Provide a Groq API key. You can create one to access free open-source models at the following link.\n",
        "\n",
        "[Groq API Creation Link](https://console.groq.com/keys)\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yobvrD3glfd4",
        "outputId": "c204236e-92a0-46cb-8b05-ea2bd8c872be"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "··········\n"
          ]
        }
      ],
      "source": [
        "os.environ[\"GROQ_API_KEY\"] = getpass.getpass()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S1dLpYboqeIK"
      },
      "source": [
        "### Provide Huggingface API Key. You can create Huggingface API key at following lin\n",
        "\n",
        "[Huggingface API Creation Link](https://huggingface.co/settings/tokens)\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TQ6scBGZlhpG",
        "outputId": "67d1d5e1-0c93-48d6-9967-19f016f6ced5"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "··········\n"
          ]
        }
      ],
      "source": [
        "os.environ[\"HF_TOKEN\"] = getpass.getpass()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "VC4pm6cPyhue"
      },
      "outputs": [],
      "source": [
        "# Helper function for printing docs\n",
        "def pretty_print_docs(docs):\n",
        "    # Iterate through each document and format the output\n",
        "    for i, d in enumerate(docs):\n",
        "        print(f\"{'-' * 50}\\nDocument {i + 1}:\")\n",
        "        print(f\"Content:\\n{d.page_content}\\n\")\n",
        "        print(\"Metadata:\")\n",
        "        for key, value in d.metadata.items():\n",
        "            print(f\"  {key}: {value}\")\n",
        "    print(f\"{'-' * 50}\")  # Final separator for clarity"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ywIflM5huW6I"
      },
      "source": [
        "### Step 1: Load and preprocess data code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "TysZpDqRuXM6"
      },
      "outputs": [],
      "source": [
        "def load_and_process_data(url):\n",
        "    loader = WebBaseLoader(url)\n",
        "    data = loader.load()\n",
        "\n",
        "    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)\n",
        "    chunks = text_splitter.split_documents(data)\n",
        "\n",
        "    for idx, chunk in enumerate(chunks):\n",
        "        chunk.metadata[\"id\"] = idx\n",
        "\n",
        "    return chunks"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_DCY2IUsucGu"
      },
      "source": [
        "### Step 2: Create vector store and BM25 retriever"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "Mr-l3FWKuccs"
      },
      "outputs": [],
      "source": [
        "def create_vector_store(chunks):\n",
        "    embeddings = HuggingFaceEmbeddings(model_name=\"nomic-ai/nomic-embed-text-v1.5\", model_kwargs = {'trust_remote_code': True})\n",
        "    vectorstore = Chroma.from_documents(chunks, embeddings)\n",
        "    return vectorstore"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Step 3: Define the re-ranking RAG pipeline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "dSFTMURrwbss"
      },
      "outputs": [],
      "source": [
        "def reranking_rag(query, vectorstore, llm):\n",
        "    # Set up the document compressor using FlashRank\n",
        "    compressor = FlashrankRerank()\n",
        "\n",
        "    # Create a compression retriever\n",
        "    compression_retriever = ContextualCompressionRetriever(\n",
        "        base_compressor=compressor,\n",
        "        base_retriever=vectorstore.as_retriever()\n",
        "    )\n",
        "\n",
        "    # Create a RetrievalQA chain\n",
        "    chain = RetrievalQA.from_chain_type(llm=llm, retriever=compression_retriever)\n",
        "\n",
        "    # Generate response\n",
        "    response = chain.invoke(query)\n",
        "\n",
        "    return {\n",
        "        \"query\": query,\n",
        "        \"final_answer\": response[\"result\"],\n",
        "        \"retrieval_method\": \"Re-ranking with FlashRank\"\n",
        "    }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yz5W3MoNvQUh"
      },
      "source": [
        "### Step 4: Create chunk of web data to Chroma Vector Store"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "PA02yiM7vQiD",
        "outputId": "dcaecda7-0e30-40cc-bf2e-c003f72356a3"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:transformers_modules.nomic-ai.nomic-bert-2048.e5042dce39060cc34bc223455f25cf1d26db4655.modeling_hf_nomic_bert:<All keys matched successfully>\n"
          ]
        }
      ],
      "source": [
        "# Initialize the language model with specified settings (Change temeprature  and other parameters as per your requirement)\n",
        "llm = ChatGroq(\n",
        "    model=\"llama3-8b-8192\",\n",
        "    temperature=0.5\n",
        ")\n",
        "\n",
        "# Load and process data\n",
        "url = \"https://en.wikipedia.org/wiki/Artificial_intelligence\"\n",
        "chunks = load_and_process_data(url)\n",
        "\n",
        "# Create vector store\n",
        "vectorstore  = create_vector_store(chunks)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LSjCJTQQ2mZF"
      },
      "outputs": [],
      "source": [
        "# Example queries\n",
        "queries = [\n",
        "        \"What are the main applications of artificial intelligence in healthcare?\",\n",
        "        \"Explain the concept of machine learning and its relationship to AI.\",\n",
        "        \"Discuss the ethical implications of AI in decision-making processes.\"\n",
        "    ]\n",
        "\n",
        "# Run Re-ranking RAG for each query\n",
        "for query in queries:\n",
        "  print(f\"\\nQuery: {query}\")\n",
        "  result = reranking_rag(query, vectorstore, llm)\n",
        "  print(\"Final Answer:\")\n",
        "  print(result[\"final_answer\"])\n",
        "  print(\"\\nRetrieval Method:\")\n",
        "  print(result[\"retrieval_method\"])\n",
        "\n",
        "# Demonstrate retrieval and re-ranking\n",
        "demo_query = \"Explain the concept of machine learning and its relationship to AI\"\n",
        "print(f\"\\nDemonstration Query: {demo_query}\")\n",
        "\n",
        "# Retrieve documents before re-ranking\n",
        "docs_before = vectorstore.similarity_search(demo_query)\n",
        "print(\"\\nDocuments before re-ranking:\")\n",
        "pretty_print_docs(docs_before)\n",
        "\n",
        "# Retrieve and re-rank documents\n",
        "compressor = FlashrankRerank()\n",
        "compression_retriever = ContextualCompressionRetriever(\n",
        "        base_compressor=compressor,\n",
        "        base_retriever=vectorstore.as_retriever()\n",
        "    )\n",
        "docs_after = compression_retriever.get_relevant_documents(demo_query)\n",
        "print(\"\\nDocuments after re-ranking:\")\n",
        "pretty_print_docs(docs_after)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
